import carla
import networkx as nx
import math
import enum
import heapq
import numpy as np

class MotionModel(enum.Enum):
    """Discrete motion primitives used by the lattice A* search.

    Every member except STOP is a key of ``AStarPlanner.motion_dict_dx_dy_cost``,
    which maps it to an ego-frame step (x > 0 forward, y > 0 right) and its cost.
    The numeric suffix appears to name the approximate angle off the forward
    axis (e.g. the "30" step is (2, 1), about 26.6 degrees) -- inferred from that
    table, not enforced anywhere.
    """
    STOP = 0  # state of the start node; next_motion_map lets it move in any direction
    RIGHT = 1  # step (0, 2): purely to the right
    FORWARD = 2  # step (2, 0): straight ahead
    LEFT = 3  # step (0, -2): purely to the left
    RIGHTFORWARD30 = 4  # step (2, 1)
    RIGHTFORWARD45 = 5  # step (2, 2)
    RIGHTFORWARD60 = 6  # step (1, 2)
    LEFTFORWARD30 = 7  # step (2, -1)
    LEFTFORWARD45 = 8  # step (2, -2)
    LEFTFORWARD60 = 9  # step (1, -2)


class Node:
    """A search-graph node: a planar position plus A* bookkeeping.

    Equality (and hashing) use the position only (``x``, ``y``); ordering uses
    ``cost`` so nodes can break ties when stored inside ``heapq`` tuples.
    Because ``x``/``y`` are plain attributes, do not mutate them while the node
    is stored in a set or used as a dict key.
    """

    def __init__(self, x, y, cost, parent, motion_state):
        self.x = x  # position along the first axis of the caller's frame
        self.y = y  # position along the second axis of the caller's frame
        self.cost = cost  # accumulated path cost from the start node
        self.parent = parent  # predecessor Node, or None for the root
        self.motion_state = motion_state  # motion primitive used to reach this node

    def __eq__(self, other):
        if not isinstance(other, Node):
            # Let Python fall back to identity instead of raising AttributeError
            # (e.g. for ``node == None``).
            return NotImplemented
        return self.x == other.x and self.y == other.y

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash the same fields that
        # __eq__ compares so equal nodes hash equally.
        return hash((self.x, self.y))

    def __lt__(self, other):
        if not isinstance(other, Node):
            return NotImplemented
        return self.cost < other.cost

    def __repr__(self):
        # ``parent`` is omitted to avoid printing the whole ancestor chain.
        return "{}(x={!r}, y={!r}, cost={!r}, motion_state={!r})".format(
            type(self).__name__, self.x, self.y, self.cost, self.motion_state)


class AStarPlanner:
    """Lattice A* local planner that searches in the ego vehicle's frame.

    Candidate positions lie on a lattice with ``sampling_resolution`` spacing in
    the ego frame (x > 0 forward, y > 0 right). Each one is converted back to
    world coordinates and accepted only if it lands on a ``Driving`` lane and
    keeps clear of the vehicles and walkers captured when planning starts.

    NOTE(review): every motion primitive has dx >= 0 in the ego frame, and the
    primitives are oriented by the ego's *current* heading for every route
    segment, so goals behind the ego's heading are unreachable.
    """

    def __init__(self, world, ego_vehicle, sampling_resolution):
        """Create a planner bound to ``world`` and ``ego_vehicle``.

        Args:
            world: CARLA world used for map and actor queries.
            ego_vehicle: actor whose transform defines the planning frame; it is
                excluded from the obstacle list.
            sampling_resolution: lattice spacing in world units (metres in CARLA).

        Raises:
            ValueError: if ``sampling_resolution`` is not strictly positive.
        """
        if not sampling_resolution > 0:
            # A non-positive step collapses every neighbour onto its parent and
            # makes the ``distance < sampling_resolution`` goal test unsatisfiable.
            raise ValueError(
                "sampling_resolution must be > 0, got {!r}".format(sampling_resolution))
        self.world = world
        self.map = world.get_map()
        self.ego_vehicle = ego_vehicle
        self.sampling_resolution = sampling_resolution
        # motion -> [dx, dy, cost] in lattice units; cost is the step's
        # Euclidean length, which keeps the Euclidean heuristic admissible.
        self.motion_dict_dx_dy_cost = {
            # x > 0 forward, y > 0 right
            MotionModel.RIGHT:          [0, 2, 2],
            MotionModel.RIGHTFORWARD60: [1, 2, np.sqrt(5)],
            MotionModel.RIGHTFORWARD45: [2, 2, np.sqrt(8)],
            MotionModel.RIGHTFORWARD30: [2, 1, np.sqrt(5)],
            MotionModel.FORWARD:        [2, 0, 2],
            MotionModel.LEFTFORWARD30:  [2, -1, np.sqrt(5)],
            MotionModel.LEFTFORWARD45:  [2, -2, np.sqrt(8)],
            MotionModel.LEFTFORWARD60:  [1, -2, np.sqrt(5)],
            MotionModel.LEFT:           [0, -2, 2],
        }
        # Allowed successor primitives: keep the current one or turn by one notch.
        self.next_motion_map = {
            MotionModel.RIGHT:          [MotionModel.RIGHT, MotionModel.RIGHTFORWARD60],
            MotionModel.RIGHTFORWARD60: [MotionModel.RIGHTFORWARD60, MotionModel.RIGHT, MotionModel.RIGHTFORWARD45],
            MotionModel.RIGHTFORWARD45: [MotionModel.RIGHTFORWARD45, MotionModel.RIGHTFORWARD60, MotionModel.RIGHTFORWARD30],
            MotionModel.RIGHTFORWARD30: [MotionModel.RIGHTFORWARD30, MotionModel.RIGHTFORWARD45, MotionModel.FORWARD],
            MotionModel.FORWARD:        [MotionModel.FORWARD, MotionModel.LEFTFORWARD30, MotionModel.RIGHTFORWARD30],
            MotionModel.LEFTFORWARD30:  [MotionModel.LEFTFORWARD30, MotionModel.FORWARD, MotionModel.LEFTFORWARD45],
            MotionModel.LEFTFORWARD45:  [MotionModel.LEFTFORWARD45, MotionModel.LEFTFORWARD30, MotionModel.LEFTFORWARD60],
            MotionModel.LEFTFORWARD60:  [MotionModel.LEFTFORWARD60, MotionModel.LEFTFORWARD45, MotionModel.LEFT],
            MotionModel.LEFT:           [MotionModel.LEFT, MotionModel.LEFTFORWARD60],
            MotionModel.STOP:           [MotionModel.LEFT,
                                         MotionModel.LEFTFORWARD60, MotionModel.LEFTFORWARD45, MotionModel.LEFTFORWARD30,
                                         MotionModel.FORWARD,
                                         MotionModel.RIGHTFORWARD30, MotionModel.RIGHTFORWARD45, MotionModel.RIGHTFORWARD60,
                                         MotionModel.RIGHT],
        }
        self.ego_vehicle_length = 2.5  # meters (not currently used by the planner)
        self.ego_vehicle_width = 1.5   # meters (not currently used by the planner)
        self.safety_distance = 0.5     # meters, added to every obstacle's radius
        self.search_radius = 70.0      # meters, max distance of a node from the start
        # (ego->world, world->ego) 4x4 matrices frozen for the duration of a
        # search, or None to read the ego transform on demand.
        self._frame = None

    # ------------------------------------------------------------------ #
    # Public entry points
    # ------------------------------------------------------------------ #
    def trace_route(self, route):
        """Plan through consecutive pairs of ``route`` locations.

        Args:
            route: iterable of world-space locations (``x``, ``y``, ``z`` read).

        Returns:
            Concatenated list of ``(waypoint, road_option)`` tuples. A segment
            with no path contributes nothing, and a message is printed.
        """
        locations = list(route)
        # One ego-frame snapshot for every segment, so all segments agree.
        owns_frame = self._freeze_frame()
        try:
            obstacles = self.get_dynamic_obstacles()
            planned_route = []
            for start_location, goal_location in zip(locations, locations[1:]):
                path = self.a_star_search(start_location, goal_location, obstacles)
                if path is None:
                    print("No path found between {} and {}".format(start_location, goal_location))
                    path = []
                planned_route.extend(path)
            return planned_route
        finally:
            self._release_frame(owns_frame)

    def a_star_search(self, start_location, goal_location, obstacles):
        """Run A* from ``start_location`` to ``goal_location`` (world space).

        Args:
            start_location: world-space start (``x``, ``y``, ``z`` read).
            goal_location: world-space goal (``x``, ``y``, ``z`` read).
            obstacles: list of dicts shaped like ``get_dynamic_obstacles`` output.

        Returns:
            List of ``(waypoint, road_option)`` from the first step after the
            start up to the goal, or None if the bounded search is exhausted.
        """
        owns_frame = self._freeze_frame()
        try:
            return self._run_search(start_location, goal_location, obstacles)
        finally:
            self._release_frame(owns_frame)

    # ------------------------------------------------------------------ #
    # Search internals
    # ------------------------------------------------------------------ #
    def _run_search(self, start_location, goal_location, obstacles):
        """Core A* loop; assumes the ego frame has already been frozen."""
        start = self.world_to_vehicle(start_location)
        goal = self.world_to_vehicle(goal_location)
        start_node = Node(start.x, start.y, 0, None, MotionModel.STOP)
        goal_node = Node(goal.x, goal.y, 0, None, None)

        # Entries are (f = g + h, node); equal f-scores fall back to Node.__lt__.
        open_set = [(self.heuristic(start_node, goal_node), start_node)]
        # Holds integer lattice indices, not raw floats (see _grid_key).
        # NOTE(review): successors depend on motion_state, so keying only on
        # position can prune headings that would reach the goal; consider
        # adding motion_state to the key if completeness matters.
        closed_set = set()
        reference_point = start  # centre of the search boundary

        while open_set:
            _, current_node = heapq.heappop(open_set)
            current_key = self._grid_key(current_node.x, current_node.y, start)
            if current_key in closed_set:
                continue  # stale duplicate entry
            closed_set.add(current_key)

            if self.is_goal_reached(current_node, goal_node):
                return self.reconstruct_path(current_node)

            for neighbor in self.get_neighbors(current_node, reference_point):
                # Cheap set lookup first; map and obstacle queries only if needed.
                if self._grid_key(neighbor.x, neighbor.y, start) in closed_set:
                    continue
                if not self.is_node_drivable(neighbor.x, neighbor.y):
                    continue
                if self.is_collision(neighbor.x, neighbor.y, obstacles):
                    continue
                f_score = neighbor.cost + self.heuristic(neighbor, goal_node)
                heapq.heappush(open_set, (f_score, neighbor))
        return None  # No path found

    def _grid_key(self, x, y, origin):
        """Map an ego-frame position to integer lattice indices relative to ``origin``.

        Node positions are ``origin + sampling_resolution * integer``, but
        summing float steps in different orders can differ in the last bits.
        Raw float tuples would therefore let one lattice cell be expanded
        several times.
        """
        resolution = self.sampling_resolution
        return (round((x - origin.x) / resolution), round((y - origin.y) / resolution))

    def heuristic(self, node, goal_node):
        """Euclidean distance between ``node`` and ``goal_node`` (ego frame).

        Admissible because every step costs its Euclidean length times the
        resolution (see ``motion_dict_dx_dy_cost``).
        """
        return math.hypot(node.x - goal_node.x, node.y - goal_node.y)

    def is_goal_reached(self, node, goal_node):
        """True when ``node`` is closer than one lattice spacing to the goal."""
        distance = math.hypot(node.x - goal_node.x, node.y - goal_node.y)
        return distance < self.sampling_resolution

    def reconstruct_path(self, node):
        """Walk parent links back from ``node`` and build a world-space path.

        The root (start) node is excluded. Nodes whose world position yields
        no waypoint (``get_waypoint`` with ``project_to_road=False`` returned
        None) are skipped.
        """
        path = []
        while node.parent is not None:
            location = self.vehicle_to_world(carla.Location(x=node.x, y=node.y))
            waypoint = self.map.get_waypoint(location, project_to_road=False)
            road_option = None  # This can be determined based on motion_state or other info
            if waypoint is not None:
                path.append((waypoint, road_option))
            node = node.parent
        path.reverse()
        return path

    def get_neighbors(self, node, reference_point):
        """Expand ``node`` with the primitives allowed after its motion state.

        Neighbours farther than ``search_radius`` from ``reference_point``
        (ego frame) are discarded.
        """
        neighbors = []
        resolution = self.sampling_resolution
        for motion in self.next_motion_map.get(node.motion_state, []):
            dx, dy, cost = self.motion_dict_dx_dy_cost[motion]
            x = node.x + dx * resolution
            y = node.y + dy * resolution

            # Boundary check: keep the search within search_radius of the reference.
            if math.hypot(x - reference_point.x, y - reference_point.y) > self.search_radius:
                continue

            neighbors.append(Node(x, y, node.cost + cost * resolution, node, motion))
        return neighbors

    def is_node_drivable(self, x, y):
        """True if ego-frame point (x, y) lies on a ``Driving`` lane."""
        location = self.vehicle_to_world(carla.Location(x=x, y=y))
        waypoint = self.map.get_waypoint(location, project_to_road=False)
        return waypoint is not None and waypoint.lane_type == carla.LaneType.Driving

    def is_collision(self, x, y, obstacles):
        """True if ego-frame point (x, y) is inside any obstacle's clearance circle.

        The distance is planar (x/y only): the node carries no height, and
        actor origins sit above the road (walkers by roughly half their
        height). A 3D distance would therefore never register a hit.
        Each obstacle is approximated by a circle of radius
        ``max(extent.x, extent.y) + safety_distance``. The ego's own footprint
        is not included.
        """
        world_x, world_y, _ = self._ego_to_world_xyz(x, y, 0.0)
        for obstacle in obstacles:
            obstacle_location = obstacle['location']
            extent = obstacle['bbox'].extent
            clearance = max(extent.x, extent.y) + self.safety_distance
            if math.hypot(world_x - obstacle_location.x, world_y - obstacle_location.y) < clearance:
                return True
        return False

    def get_dynamic_obstacles(self):
        """Snapshot all vehicles (except the ego) and walkers in the world.

        Returns:
            List of dicts with keys ``'type'`` (``'vehicle'``/``'walker'``),
            ``'location'`` and ``'bbox'``.
        """
        actors = self.world.get_actors()
        obstacles = [
            {'type': 'vehicle', 'location': vehicle.get_location(), 'bbox': vehicle.bounding_box}
            for vehicle in actors.filter('vehicle.*')
            if vehicle.id != self.ego_vehicle.id
        ]
        obstacles.extend(
            {'type': 'walker', 'location': walker.get_location(), 'bbox': walker.bounding_box}
            for walker in actors.filter('walker.*')
        )
        return obstacles

    # ------------------------------------------------------------------ #
    # Frame conversions
    # ------------------------------------------------------------------ #
    def _frame_matrices(self):
        """Return ``(ego->world, world->ego)`` 4x4 matrices.

        Uses the frozen snapshot when one is active; otherwise reads the ego
        transform now.
        """
        frame = getattr(self, '_frame', None)
        if frame is not None:
            return frame
        transform = self.ego_vehicle.get_transform()
        return np.array(transform.get_matrix()), np.array(transform.get_inverse_matrix())

    def _freeze_frame(self):
        """Freeze the ego frame unless one is already frozen.

        Returns:
            True if this call took the snapshot (and so must release it).
        """
        if getattr(self, '_frame', None) is not None:
            return False
        self._frame = self._frame_matrices()
        return True

    def _release_frame(self, owned):
        """Drop the frozen frame if ``owned`` (as returned by ``_freeze_frame``)."""
        if owned:
            self._frame = None

    def _ego_to_world_xyz(self, x, y, z=0.0):
        """Transform an ego-frame point to world coordinates as plain floats."""
        ego_to_world, _ = self._frame_matrices()
        world_x, world_y, world_z, _ = ego_to_world @ np.array([x, y, z, 1.0])
        return float(world_x), float(world_y), float(world_z)

    def world_to_vehicle(self, location):
        """Express a world-space location in the ego frame (result has z = 0)."""
        _, world_to_ego = self._frame_matrices()
        ego = world_to_ego @ np.array([location.x, location.y, location.z, 1.0])
        return carla.Location(x=float(ego[0]), y=float(ego[1]))

    def vehicle_to_world(self, location):
        """Express an ego-frame location in world space.

        The world z is kept, so map queries happen at the ego's height rather
        than at z = 0.
        """
        world_x, world_y, world_z = self._ego_to_world_xyz(location.x, location.y, location.z)
        return carla.Location(x=world_x, y=world_y, z=world_z)